"""
conda activate rebuttal
CUDA_VISIBLE_DEVICES=1 python eval_FID.py
"""


# Input locations (relative paths, resolved against the working directory):
#   dataset_dir - root folder that contains both video sets
#   test_id     - sub-folder of dataset_dir holding the videos to evaluate
#   gt_id       - sub-folder of dataset_dir holding the reference videos
# TODO: confirm test_id points at the run that should be evaluated.
dataset_dir, test_id, gt_id = "./../dataset", "rebuttal/video", "GT"


import json
import os
import subprocess
import sys

import cv2
import numpy as np
import PIL
from PIL import Image
from tqdm import tqdm

# Number of frames taken from the start of each video.
num_frames = 14
# Output frame size in pixels (height, width); cv2.resize receives it as (W, H).
H = 320
W = 576

#extract frames
def extract_frames(file_id):
    """Dump the first ``num_frames`` frames of every ``.mp4`` in a folder as PNGs.

    Videos are read from ``<dataset_dir>/<file_id>``. Each frame is resized
    to ``W`` x ``H`` and written to ``<dataset_dir>/<file_id>_extract`` as
    ``<video-stem>_<frame-index>.png``.

    If the output directory already exists, it is treated as complete and
    returned without extracting anything.

    Args:
        file_id: Sub-directory of ``dataset_dir`` that holds the videos.

    Returns:
        Path of the directory containing the extracted PNG frames.

    Raises:
        AssertionError: if the input directory does not contain 329 or 330
            entries (dataset-specific sanity check).
    """
    extract_dir = os.path.join(dataset_dir, file_id)
    save_dir = os.path.join(dataset_dir, file_id + "_extract")
    if os.path.exists(save_dir):
        # NOTE: a directory left behind by an interrupted run is reused as-is;
        # delete it manually to force re-extraction.
        print(save_dir + " exists!!")
        return save_dir

    # Sorted so progress logs are reproducible across runs.
    file_list = sorted(os.listdir(extract_dir))
    print(len(file_list))
    # Validate BEFORE creating save_dir. Otherwise a failed check leaves an
    # empty output dir that the exists() early-return above would reuse.
    assert len(file_list) in (329, 330), (
        f"unexpected number of entries in {extract_dir}: {len(file_list)}"
    )
    os.makedirs(save_dir, exist_ok=True)

    suffix = ".mp4"
    for i, f in enumerate(file_list):
        if not f.endswith(suffix):
            continue
        print("reading:", i, f)
        stem = f[: -len(suffix)]
        cam = cv2.VideoCapture(os.path.join(extract_dir, f))
        try:
            for ctr in range(num_frames):
                ok, frame = cam.read()
                if not ok:
                    # The video ended early or could not be decoded. Keep the
                    # frames read so far instead of passing None to cv2.resize.
                    print(f"warning: {f} yielded only {ctr} of {num_frames} frames")
                    break
                frame = cv2.resize(frame, (W, H))
                cv2.imwrite(os.path.join(save_dir, f"{stem}_{ctr}.png"), frame)
        finally:
            # Always free the decoder, even if resize/imwrite raises.
            cam.release()
    return save_dir
    
    
# Extract frames for the evaluated set (test_id) and the reference set (gt_id).
save_dir1 = extract_frames(test_id)
save_dir2 = extract_frames(gt_id)

print(save_dir1, save_dir2)

# Evaluate FID between the two frame folders with the pytorch_fid CLI.
# - sys.executable runs it under the current interpreter/environment rather
#   than whatever `python` happens to be first on PATH.
# - An argument list (no shell) means paths with spaces or shell
#   metacharacters are passed through intact.
# - check=True surfaces a failing FID run instead of silently ignoring it.
subprocess.run(
    [sys.executable, "-m", "pytorch_fid", save_dir1, save_dir2],
    check=True,
)
